{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "id": "18zNHCLg5C3t"
   },
   "outputs": [],
   "source": [
    "# Use %pip (not !pip) so the packages are installed into the running kernel's environment.\n",
    "# scikit-learn is listed explicitly because the evaluation cells import sklearn.metrics.\n",
    "# The extras spec is quoted so the shell never tries to glob the brackets.\n",
    "%pip install datasets evaluate \"transformers[torch]\" scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VzKzjJJb542a"
   },
   "outputs": [],
   "source": [
    "import glob\n",
    "import shutil\n",
    "from collections import Counter\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import datasets\n",
    "from datasets import load_dataset, DatasetDict\n",
    "import evaluate\n",
    "from evaluate import evaluator\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments, Trainer, pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FfOmOPtnuMOw"
   },
   "outputs": [],
   "source": [
    "# Seed for every random step below, so the class balance and the\n",
    "# train/valid/test split (and hence the reported test metrics) are the same on every run.\n",
    "SEED = 42\n",
    "\n",
    "raw_dataset = load_dataset(\"csv\", data_files=\"./style_train.csv\", split='train')\n",
    "\n",
    "# Assuming 'label' is the column name for your labels\n",
    "label_counts = Counter(raw_dataset['label'])\n",
    "\n",
    "# Every class is cut down to the size of the smallest one\n",
    "min_class_size = min(label_counts.values())\n",
    "\n",
    "def undersample_dataset(dataset, min_class_size, seed=42):\n",
    "    \"\"\"Randomly keep at most `min_class_size` rows per label.\n",
    "\n",
    "    The dataset is shuffled with `seed` first, so the kept rows are a random\n",
    "    sample of each class. Returns a new Dataset built with `select`.\n",
    "    \"\"\"\n",
    "    dataset = dataset.shuffle(seed=seed)\n",
    "\n",
    "    # Read only the label column instead of materialising every full row\n",
    "    labels = list(dataset['label'])\n",
    "    n_classes = len(set(labels))\n",
    "    # Each class is capped at min_class_size, so reaching this total means every class is full\n",
    "    target_size = n_classes * min_class_size\n",
    "\n",
    "    undersampled_indices = []\n",
    "    samples_per_class = Counter()\n",
    "    for index, label in enumerate(labels):\n",
    "        if samples_per_class[label] < min_class_size:\n",
    "            undersampled_indices.append(index)\n",
    "            samples_per_class[label] += 1\n",
    "            # Stop only when *all* classes are full. Checking just the classes seen\n",
    "            # so far would stop before rarer labels had appeared at all.\n",
    "            if len(undersampled_indices) == target_size:\n",
    "                break\n",
    "\n",
    "    return dataset.select(undersampled_indices)\n",
    "\n",
    "balanced_dataset = undersample_dataset(raw_dataset, min_class_size, seed=SEED)\n",
    "\n",
    "# 70% train, then the remaining 30% is halved into 15% valid / 15% test\n",
    "train_testvalid = balanced_dataset.train_test_split(test_size=0.3, seed=SEED)\n",
    "test_valid = train_testvalid[\"test\"].train_test_split(test_size=0.5, seed=SEED)\n",
    "dataset = DatasetDict({\n",
    "    \"train\": train_testvalid[\"train\"],\n",
    "    \"test\": test_valid[\"test\"],\n",
    "    \"valid\": test_valid[\"train\"]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7MdWEZVV87bL"
   },
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"KB/bert-base-swedish-cased\")\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize a batch of rows taken from the 'headl_text' column.\n",
    "\n",
    "    Sequences are truncated to 512 tokens. No padding is applied here;\n",
    "    the data collator pads each batch when it is assembled.\n",
    "    \"\"\"\n",
    "    texts = examples[\"headl_text\"]\n",
    "    return tokenizer(texts, max_length=512, truncation=True)\n",
    "\n",
    "# Pads every training/eval batch to the longest sequence in that batch.\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
    "\n",
    "tokenized_data = dataset.map(preprocess_function, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "57JqV4KPWH-t"
   },
   "outputs": [],
   "source": [
    "# Pretrained Swedish BERT with a sequence-classification head for 2 labels.\n",
    "base_checkpoint = \"KB/bert-base-swedish-cased\"\n",
    "model = AutoModelForSequenceClassification.from_pretrained(base_checkpoint, num_labels=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uX30SRXX9BV9"
   },
   "outputs": [],
   "source": [
    "# Evaluate and checkpoint once per epoch so the validation split is actually used,\n",
    "# then restore the checkpoint with the lowest validation loss at the end instead of\n",
    "# keeping whatever the last of the 25 epochs produced.\n",
    "# (`eval_strategy` needs transformers >= 4.41; older versions call it `evaluation_strategy`.)\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./results\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=8,\n",
    "    per_device_eval_batch_size=8,\n",
    "    num_train_epochs=25,\n",
    "    weight_decay=0.01,\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"eval_loss\",\n",
    "    greater_is_better=False,\n",
    "    save_total_limit=2,  # bound disk use; the best checkpoint is still retained\n",
    "    seed=42,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_data[\"train\"],\n",
    "    eval_dataset=tokenized_data[\"valid\"],\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9oxwZ850-HI-"
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score\n",
    "\n",
    "# Score the held-out test split; argmax over the last axis of the outputs gives the class id\n",
    "predictions = trainer.predict(tokenized_data[\"test\"])\n",
    "preds = np.argmax(predictions.predictions, axis=-1)\n",
    "references = predictions.label_ids\n",
    "\n",
    "# Metric objects and their result dicts use separate names, so each name\n",
    "# always refers to one kind of object\n",
    "recall_metric = evaluate.load(\"recall\")\n",
    "f1_metric = evaluate.load(\"f1\")\n",
    "precision_metric = evaluate.load(\"precision\")\n",
    "f1_result = f1_metric.compute(predictions=preds, references=references)\n",
    "recall_result = recall_metric.compute(predictions=preds, references=references)\n",
    "precision_result = precision_metric.compute(predictions=preds, references=references)\n",
    "\n",
    "# sklearn's signature is (y_true, y_pred); passing them the other way round\n",
    "# silently computes macro precision instead of balanced accuracy\n",
    "balanced_accuracy_result = balanced_accuracy_score(references, preds)\n",
    "\n",
    "print(precision_result)\n",
    "print(recall_result)\n",
    "print(f1_result)\n",
    "print(balanced_accuracy_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xRjWOAU4mWSU"
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "# Test-set labels vs. predicted class ids from the evaluation cell.\n",
    "# sklearn puts true labels on the rows and predicted labels on the columns.\n",
    "y_true = predictions.label_ids\n",
    "y_pred = preds\n",
    "conf_matrix = confusion_matrix(y_true, y_pred)\n",
    "\n",
    "print(\"Confusion Matrix:\\n\", conf_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AbJlg1qWfzYL"
   },
   "outputs": [],
   "source": [
    "# Persist the trained model to disk; the prediction cells reload it from this directory.\n",
    "trainer.save_model(output_dir=\"./fine_tuned_model_style\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fvZ0V-CLfzat"
   },
   "outputs": [],
   "source": [
    "# get articles for prediction part 1\n",
    "articles = pd.read_csv('./all_articles_for_pred_1.csv')\n",
    "\n",
    "# transformers wants a list of strings as input; empty CSV cells are read as NaN,\n",
    "# so replace them with empty strings before tokenization\n",
    "text = articles['headl_text'].fillna('').tolist()\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"KB/bert-base-swedish-cased\")\n",
    "\n",
    "# load the fine-tuned model written by trainer.save_model\n",
    "fine_tuned_model = AutoModelForSequenceClassification.from_pretrained(\"./fine_tuned_model_style\", num_labels=2)\n",
    "\n",
    "# use the first GPU when one is available, otherwise run on CPU (-1)\n",
    "device = 0 if torch.cuda.is_available() else -1\n",
    "\n",
    "# set up and run pipeline for inference\n",
    "clf = pipeline(task=\"text-classification\", model=fine_tuned_model, tokenizer=tokenizer, truncation=True, device=device, max_length=512)\n",
    "answer = clf(text)\n",
    "\n",
    "# export under a new name so `predictions` (the test-set output used by the\n",
    "# confusion-matrix cell) is not overwritten with a DataFrame\n",
    "answer_df = pd.DataFrame(answer)\n",
    "predictions_part1 = pd.concat([articles.reset_index(drop=True), answer_df], axis=1)\n",
    "predictions_part1 = predictions_part1[[\"id\", \"label\", \"score\"]]\n",
    "predictions_part1.to_csv('./predictions_style_1.csv', sep=',', index=False, encoding='utf-8')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K7X2opE-XA2t"
   },
   "outputs": [],
   "source": [
    "# get articles for prediction part 2\n",
    "articles = pd.read_csv('./all_articles_for_pred_2.csv')\n",
    "\n",
    "# transformers wants a list of strings as input; empty CSV cells are read as NaN,\n",
    "# so replace them with empty strings before tokenization\n",
    "text = articles['headl_text'].fillna('').tolist()\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"KB/bert-base-swedish-cased\")\n",
    "\n",
    "# load the fine-tuned model written by trainer.save_model\n",
    "fine_tuned_model = AutoModelForSequenceClassification.from_pretrained(\"./fine_tuned_model_style\", num_labels=2)\n",
    "\n",
    "# use the first GPU when one is available, otherwise run on CPU (-1)\n",
    "device = 0 if torch.cuda.is_available() else -1\n",
    "\n",
    "# set up and run pipeline for inference\n",
    "clf = pipeline(task=\"text-classification\", model=fine_tuned_model, tokenizer=tokenizer, truncation=True, device=device, max_length=512)\n",
    "answer = clf(text)\n",
    "\n",
    "# export under a new name so `predictions` (the test-set output used by the\n",
    "# confusion-matrix cell) is not overwritten with a DataFrame\n",
    "answer_df = pd.DataFrame(answer)\n",
    "predictions_part2 = pd.concat([articles.reset_index(drop=True), answer_df], axis=1)\n",
    "predictions_part2 = predictions_part2[[\"id\", \"label\", \"score\"]]\n",
    "predictions_part2.to_csv('./predictions_style_2.csv', sep=',', index=False, encoding='utf-8')\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyMBQ/tBu1usOMZFVfC0ZlSs",
   "gpuType": "L4",
   "machine_shape": "hm",
   "provenance": [
    {
     "file_id": "18WWqSHctzuD3CRIY_1krw64Oma7c1oAu",
     "timestamp": 1697881380924
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
